ENH: introduce NEP 50 "weak scalars" #140
Conversation
Make python scalars "weak": in type promotion, they do not type promote arrays: `(np.int8(3) + 4).dtype == int8`.
- Note that array scalars (`np.int8(3)` etc.) are 0D arrays, so they are not weak.
- Converting a weak scalar to an array (`asarray(3)` etc.) makes it not weak.
- Scalars are only weak in ufuncs. In places like `np.dot([1, 2, 3], 4.0)`, the result is float64.
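A minimal sketch of these rules in code. The `import torch_np as np` spelling and the int64/float64 defaults are assumptions on my part; the assertions simply restate the bullets above.

```python
import torch_np as np  # assumed import name for this repo's namespace

# A Python int is weak: it adapts to the array operand's dtype.
assert (np.int8(3) + 4).dtype == np.int8

# Array scalars are 0D arrays, hence not weak: int8 + int64 -> int64.
assert (np.int8(3) + np.int64(4)).dtype == np.int64

# asarray() turns the weak scalar into a regular array with the default
# int dtype, so it participates in promotion like any other array.
assert (np.int8(3) + np.asarray(4)).dtype == np.int64

# Outside of ufuncs, e.g. in np.dot, the Python float gives float64.
assert np.dot([1, 2, 3], 4.0).dtype == np.float64
```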
Grrr, no. There's more to NEP 50. Converting to draft for now.
CI's green, finally!
Based on an off-line discussion with @lezcano, this PR now:
A 70 line implementation of NEP-50. How cool is that?
torch_np/_dtypes_impl.py (outdated)
    # detect uint overflow: in PyTorch, uint8(-1) wraps around to 255,
    # while NEP50 mandates an exception.
    if weak_.dtype == torch.uint8 and weak_.item() != weak:
Does this happen just for `uint8`, or for any int dtype?
Also, using `.item()` is not kosher. Let's do `0 <= weak < 2**8` before doing the `as_tensor`.
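A rough sketch of the range check suggested here; the helper and its error type are my own illustration (the `weak` name follows the diff above), and the thread below ends up matching numpy's behaviour instead of raising.

```python
import torch

def as_uint8_tensor(weak):
    # Hypothetical helper: validate the Python int before as_tensor, since
    # (per the diff comment above) PyTorch wraps out-of-range uint8 values.
    if not (0 <= weak < 2**8):
        raise OverflowError(f"Python integer {weak} does not fit in uint8")
    return torch.as_tensor(weak, dtype=torch.uint8)
```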
This is a bit different: checking `weak`'s value does not detect `uint8(100) + 200`. However, numpy warns, not raises, so we shouldn't raise either. As discussed, this PR now does what numpy does, sans RuntimeWarnings.
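For concreteness, a sketch of the behaviour being matched; the exact return type in `torch_np` is an assumption, the arithmetic is plain mod-256 wrap-around.

```python
import torch_np as np  # assumed import name, as above

x = np.uint8(100)
y = x + 200           # 200 fits in uint8, but 100 + 200 = 300 overflows ...
assert y.dtype == np.uint8
assert int(y) == 44   # ... and wraps to 300 - 256 = 44, silently here
                      # (NumPy itself emits a RuntimeWarning for this)
```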
Also, it's a bit annoying that this check is just done for ints. If it were done for all dtypes, we could create the tensors with `torch.full`, which does check if the number fits in the given type.
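A sketch of that alternative, taking the comment's claim about `torch.full` at face value; the helper name is made up.

```python
import torch

def scalar_to_tensor(value, dtype):
    # torch.full with a 0D shape builds the scalar tensor and, per the
    # comment above, rejects fill values that do not fit in `dtype`.
    return torch.full((), value, dtype=dtype)

scalar_to_tensor(3, torch.uint8)    # tensor(3, dtype=torch.uint8)
scalar_to_tensor(300, torch.uint8)  # raises: 300 cannot be represented as uint8
```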
I think that is also a valid choice; numpy already gives a customizable warning anyway if you overflow to `inf`, so this just seemed the easier/OK thing. ints don't overflow gracefully though...
Perhaps @seberg could have a look at this one as well?
The truly complicated stuff is if you have more than 2 operands; without that, this seems fine (no, you cannot just do a ...).

I am not actually sure about that (OTOH, I do feel that beyond ufuncs the issue is probably small enough that missing a few functions isn't the end of the world.)
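To illustrate why more than two operands is tricky (this example is not from the thread; it assumes NumPy's NEP 50 behaviour of `np.result_type`, where a naive left-to-right pairwise reduction can disagree with considering all operands at once):

```python
import numpy as np  # assumes NumPy >= 2.0, where NEP 50 is the default

i8 = np.zeros(3, dtype=np.int8)
f16 = np.zeros(3, dtype=np.float16)

# All operands at once: the Python float stays weak against float16,
# and int8 combined with float16 promotes to float16.
np.result_type(i8, 1.0, f16)                  # float16

# Pairwise reduction: int8 + weak Python float gives float64 first,
# and that float64 then dominates the float16 array.
np.result_type(np.result_type(i8, 1.0), f16)  # float64
```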
OK, two updates:
Make python scalars "weak" [1], meaning that in type promotion, they do not type promote arrays: `(np.int8(3) + 4).dtype == int8`. In places like `np.dot([1, 2, 3], 4.0)`, the result is float64.

[1] https://numpy.org/neps/nep-0050-scalar-promotion.html
This is an alternative to gh-137, which it supersedes and closes.
NB: tests fail. The majority of failures are, I believe, in the internals of `torch_np.testing`, which are vendored from numpy. That code is incredibly messy and due for a face-lift; I am migrating it to rely on `torch.testing.assert_close`.
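For reference, a minimal example of the `torch.testing.assert_close` call the tests are being migrated to; the tensors are placeholders.

```python
import torch

actual = torch.arange(3, dtype=torch.float32)
expected = torch.tensor([0.0, 1.0, 2.0 + 1e-7])

# Checks values (within dtype-dependent rtol/atol), shape and dtype,
# raising an AssertionError on mismatch.
torch.testing.assert_close(actual, expected)
```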